{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BGE-VL-v1&v1.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we will go through the multimodel retrieval models BGE-VL series, which achieved state-of-the-art performance on four popular zero-shot composed image retrieval benchmarks and the massive multimodal embedding benchmark (MMEB)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Installation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the required packages in your environment.\n",
    "\n",
    "- Our model works well on transformers==4.45.2, and we recommend using this version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install numpy torch transformers pillow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. BGE-VL-CLIP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Model  | Language |   Parameters   |   Model Size   |    Description    |   Base Model     |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-vl-base](https://huggingface.co/BAAI/BGE-VL-base)       | English |    150M    |    299 MB    |          Light weight multimodel embedder among image and text                   |  CLIP-base   |\n",
    "| [BAAI/bge-vl-large](https://huggingface.co/BAAI/BGE-VL-large)     | English |    428M    |    855 MB    |          Large scale multimodel embedder among image and text                    |  CLIP-large  |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BGE-VL-base and BGE-VL-large are trained based on CLIP base and CLIP large, which both contain a vision transformer and a text transformer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CLIPModel(\n",
       "  (text_model): CLIPTextTransformer(\n",
       "    (embeddings): CLIPTextEmbeddings(\n",
       "      (token_embedding): Embedding(49408, 512)\n",
       "      (position_embedding): Embedding(77, 512)\n",
       "    )\n",
       "    (encoder): CLIPEncoder(\n",
       "      (layers): ModuleList(\n",
       "        (0-11): 12 x CLIPEncoderLayer(\n",
       "          (self_attn): CLIPSdpaAttention(\n",
       "            (k_proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (v_proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (q_proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          )\n",
       "          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): CLIPMLP(\n",
       "            (activation_fn): QuickGELUActivation()\n",
       "            (fc1): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (fc2): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          )\n",
       "          (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (vision_model): CLIPVisionTransformer(\n",
       "    (embeddings): CLIPVisionEmbeddings(\n",
       "      (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)\n",
       "      (position_embedding): Embedding(197, 768)\n",
       "    )\n",
       "    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "    (encoder): CLIPEncoder(\n",
       "      (layers): ModuleList(\n",
       "        (0-11): 12 x CLIPEncoderLayer(\n",
       "          (self_attn): CLIPSdpaAttention(\n",
       "            (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "          )\n",
       "          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): CLIPMLP(\n",
       "            (activation_fn): QuickGELUActivation()\n",
       "            (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          )\n",
       "          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (post_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (visual_projection): Linear(in_features=768, out_features=512, bias=False)\n",
       "  (text_projection): Linear(in_features=512, out_features=512, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoModel\n",
    "\n",
    "MODEL_NAME = \"BAAI/BGE-VL-base\" # or \"BAAI/BGE-VL-base\"\n",
    "\n",
    "model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True) # You must set trust_remote_code=True\n",
    "model.set_processor(MODEL_NAME)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2647, 0.1242]])\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    query = model.encode(\n",
    "        images = \"../../imgs/cir_query.png\", \n",
    "        text = \"Make the background dark, as if the camera has taken the photo at night\"\n",
    "    )\n",
    "\n",
    "    candidates = model.encode(\n",
    "        images = [\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"]\n",
    "    )\n",
    "    \n",
    "    scores = query @ candidates.T\n",
    "print(scores)"
   ]
  },
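  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first candidate scores higher, i.e. the model judges it a better match for the composed query. As a small follow-up, we can rank the candidates by score. This is a minimal sketch that reuses the `scores` tensor from the previous cell; the candidate paths are repeated here only for printing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank the candidate images by their similarity to the composed query\n",
    "candidate_paths = [\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"]\n",
    "\n",
    "ranking = torch.argsort(scores.squeeze(0), descending=True)\n",
    "for rank, idx in enumerate(ranking.tolist(), start=1):\n",
    "    print(f\"{rank}. {candidate_paths[idx]} (score: {scores[0, idx].item():.4f})\")"
   ]
  },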
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. BGE-VL-MLLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Model  | Language |   Parameters   |   Model Size   |    Description    |   Base Model     |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-vl-MLLM-S1](https://huggingface.co/BAAI/BGE-VL-MLLM-S1)       | English |    7.57B    |    15.14 GB    |   SOTA in composed image retrieval, trained on MegaPairs dataset      |  LLaVA-1.6   |\n",
    "| [BAAI/bge-vl-MLLM-S2](https://huggingface.co/BAAI/BGE-VL-MLLM-S2)       | English |    7.57B    |    15.14 GB    |   Finetune BGE-VL-MLLM-S1 with one epoch on MMEB training set         |  LLaVA-1.6   |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/_distutils_hack/__init__.py:54: UserWarning: Reliance on distutils from stdlib is deprecated. Users must rely on setuptools to provide the distutils module. Avoid importing distutils or import setuptools first, and avoid setting SETUPTOOLS_USE_DISTUTILS=stdlib. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml\n",
      "  warnings.warn(\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.28it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LLaVANextForEmbedding(\n",
       "  (vision_tower): CLIPVisionModel(\n",
       "    (vision_model): CLIPVisionTransformer(\n",
       "      (embeddings): CLIPVisionEmbeddings(\n",
       "        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
       "        (position_embedding): Embedding(577, 1024)\n",
       "      )\n",
       "      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "      (encoder): CLIPEncoder(\n",
       "        (layers): ModuleList(\n",
       "          (0-23): 24 x CLIPEncoderLayer(\n",
       "            (self_attn): CLIPSdpaAttention(\n",
       "              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            )\n",
       "            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "            (mlp): CLIPMLP(\n",
       "              (activation_fn): QuickGELUActivation()\n",
       "              (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "              (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "            )\n",
       "            (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "  )\n",
       "  (multi_modal_projector): LlavaNextMultiModalProjector(\n",
       "    (linear_1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "    (act): GELUActivation()\n",
       "    (linear_2): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "  )\n",
       "  (language_model): MistralForCausalLM(\n",
       "    (model): MistralModel(\n",
       "      (embed_tokens): Embedding(32005, 4096)\n",
       "      (layers): ModuleList(\n",
       "        (0-31): 32 x MistralDecoderLayer(\n",
       "          (self_attn): MistralSdpaAttention(\n",
       "            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "            (rotary_emb): MistralRotaryEmbedding()\n",
       "          )\n",
       "          (mlp): MistralMLP(\n",
       "            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "            (act_fn): SiLU()\n",
       "          )\n",
       "          (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)\n",
       "          (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)\n",
       "        )\n",
       "      )\n",
       "      (norm): MistralRMSNorm((4096,), eps=1e-05)\n",
       "    )\n",
       "    (lm_head): Linear(in_features=4096, out_features=32005, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModel\n",
    "from PIL import Image\n",
    "\n",
    "MODEL_NAME= \"BAAI/BGE-VL-MLLM-S1\"\n",
    "\n",
    "model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
    "model.eval()\n",
    "model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.4109, 0.1807]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    model.set_processor(MODEL_NAME)\n",
    "\n",
    "    query_inputs = model.data_process(\n",
    "        text=\"Make the background dark, as if the camera has taken the photo at night\", \n",
    "        images=\"../../imgs/cir_query.png\",\n",
    "        q_or_c=\"q\",\n",
    "        task_instruction=\"Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: \"\n",
    "    )\n",
    "\n",
    "    candidate_inputs = model.data_process(\n",
    "        images=[\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"],\n",
    "        q_or_c=\"c\",\n",
    "    )\n",
    "\n",
    "    query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]\n",
    "    candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :]\n",
    "    \n",
    "    query_embs = torch.nn.functional.normalize(query_embs, dim=-1)\n",
    "    candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1)\n",
    "\n",
    "    scores = torch.matmul(query_embs, candi_embs.T)\n",
    "print(scores)"
   ]
  },
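  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The embedding of each input is the hidden state of the last token (the `[:, -1, :]` slice), and since both sides are L2-normalized, the dot product is a cosine similarity. For a larger candidate pool you would typically encode candidates in batches and reuse the normalized embeddings. Below is a minimal sketch under that assumption, using the same `data_process` interface shown above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_candidates(model, image_paths, batch_size=4):\n",
    "    \"\"\"Encode candidate images in batches; returns L2-normalized embeddings.\"\"\"\n",
    "    all_embs = []\n",
    "    with torch.no_grad():\n",
    "        for i in range(0, len(image_paths), batch_size):\n",
    "            inputs = model.data_process(\n",
    "                images=image_paths[i : i + batch_size],\n",
    "                q_or_c=\"c\",\n",
    "            )\n",
    "            # Last-token hidden state serves as the embedding, as above\n",
    "            embs = model(**inputs, output_hidden_states=True)[:, -1, :]\n",
    "            all_embs.append(torch.nn.functional.normalize(embs, dim=-1))\n",
    "    return torch.cat(all_embs, dim=0)\n",
    "\n",
    "# Reusing the two tutorial candidates; query_embs comes from the previous cell\n",
    "candi_embs = encode_candidates(model, [\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"])\n",
    "print(torch.matmul(query_embs, candi_embs.T))"
   ]
  },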
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. BGE-VL-v1.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BGE-VL-v1.5 series is a new version of BGE-VL, bringing better performance on both retrieval and multi-modal understanding. It is trained on 30M MegaPairs data and extra 10M natural and synthetic data.\n",
    "\n",
    "`bge-vl-v1.5-zs` is a zero-shot model, only trained on the data mentioned above. `bge-vl-v1.5-mmeb` is the fine-tuned version on MMEB training set."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Model  | Language |   Parameters   |   Model Size   |    Description    |   Base Model     |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/BGE-VL-v1.5-zs](https://huggingface.co/BAAI/BGE-VL-v1.5-zs)       | English |    7.57B    |    15.14 GB    |    Better multi-modal retrieval model with performs well in all kinds of tasks    |  LLaVA-1.6   |\n",
    "| [BAAI/BGE-VL-v1.5-mmeb](https://huggingface.co/BAAI/BGE-VL-v1.5-mmeb)       | English |    7.57B    |    15.14 GB    |    Better multi-modal retrieval model, additionally fine-tuned on MMEB training set    |  LLaVA-1.6   |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use BGE-VL-v1.5 models in the exact same way as BGE-VL-MLLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.3880, 0.1815]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModel\n",
    "from PIL import Image\n",
    "\n",
    "MODEL_NAME= \"BAAI/BGE-VL-v1.5-mmeb\" # \"BAAI/BGE-VL-v1.5-zs\"\n",
    "\n",
    "model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
    "model.eval()\n",
    "model.cuda()\n",
    "\n",
    "with torch.no_grad():\n",
    "    model.set_processor(MODEL_NAME)\n",
    "\n",
    "    query_inputs = model.data_process(\n",
    "        text=\"Make the background dark, as if the camera has taken the photo at night\", \n",
    "        images=\"../../imgs/cir_query.png\",\n",
    "        q_or_c=\"q\",\n",
    "        task_instruction=\"Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: \"\n",
    "    )\n",
    "\n",
    "    candidate_inputs = model.data_process(\n",
    "        images=[\"../../imgs/cir_candi_1.png\", \"../../imgs/cir_candi_2.png\"],\n",
    "        q_or_c=\"c\",\n",
    "    )\n",
    "\n",
    "    query_embs = model(**query_inputs, output_hidden_states=True)[:, -1, :]\n",
    "    candi_embs = model(**candidate_inputs, output_hidden_states=True)[:, -1, :]\n",
    "    \n",
    "    query_embs = torch.nn.functional.normalize(query_embs, dim=-1)\n",
    "    candi_embs = torch.nn.functional.normalize(candi_embs, dim=-1)\n",
    "\n",
    "    scores = torch.matmul(query_embs, candi_embs.T)\n",
    "print(scores)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
